{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "af176fb5-fb00-4ec9-9b08-acf78e6566d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import graphviz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b56ff20f-7a4e-46d3-98f6-6c17db27d9d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ScalarTmp:\n",
    "    def __init__(self,values,prevs = [],op = None,label = None):\n",
    "        self.values = values\n",
    "        self.prevs = prevs\n",
    "        self.op = op\n",
    "        self.label = label\n",
    "\n",
    "    def __add__(self,other):\n",
    "        # 定义加法运算\n",
    "        values = self.values + other.values\n",
    "        output = ScalarTmp(values,prevs=[self,other],op=\"+\")\n",
    "        return output\n",
    "\n",
    "    def __mul__(self,other):\n",
    "        # 定义乘法法运算\n",
    "        values = self.values * other.values\n",
    "        output = ScalarTmp(values,prevs=[self,other],op=\"*\")\n",
    "        return output\n",
    "    def __repr__(self):\n",
    "        # 打印类的信息\n",
    "        return f'{self.values} | {self.op} | {self.label}'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d0e5abff-c984-4e81-99ac-8b251fe972b9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3.0 | + | None\n"
     ]
    }
   ],
   "source": [
    "a = ScalarTmp(1.0,label=\"a\")\n",
    "b = ScalarTmp(2.0,label=\"b\")\n",
    "c = a + b\n",
    "print(c)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "id": "fc801c67-a1d1-42d5-96ad-85e4baca47a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from graphviz import Digraph\n",
    "\n",
    "def _trace(root):\n",
    "    # 遍历计算图中的所有点与边\n",
    "    nodes, edges = set(), set()\n",
    "    def _build(v):\n",
    "        if v not in nodes:\n",
    "            nodes.add(v)\n",
    "            for prev in v.prevs:\n",
    "                edges.add((prev,v))\n",
    "                _build(prev)\n",
    "    _build(root)\n",
    "    return nodes,edges\n",
    "\n",
    "def draw_graph(root,direction = 'forward'):\n",
    "    nodes,edges = _trace(root)\n",
    "    rankdir = 'BT' if direction == 'forward' else 'TB'\n",
    "    graph = Digraph(format=\"svg\",graph_attr={'rankdir':rankdir})\n",
    "    # 画点\n",
    "    for node in nodes:\n",
    "        label = node.label if node.op is None else node.op\n",
    "        node_attr = f'{{grad = {node.grad} | value = {node.values:.2f} | {label} }}'\n",
    "        uid = str(id(node))\n",
    "        graph.node(name = uid,label = node_attr, shape = 'record')\n",
    "    # 画边\n",
    "    for edge in edges:\n",
    "        id1 = str(id(edge[0]))\n",
    "        id2 = str(id(edge[1]))\n",
    "        graph.edge(id1,id2)\n",
    "    return graph\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "id": "4dc826e3-da86-48fb-b0a2-b30e6e1ffa05",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 2.50.0 (0)\n",
       " -->\n",
       "<!-- Pages: 1 -->\n",
       "<svg width=\"94pt\" height=\"78pt\"\n",
       " viewBox=\"0.00 0.00 94.00 78.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 74)\">\n",
       "<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-74 90,-74 90,4 -4,4\"/>\n",
       "<!-- 2311754063504 -->\n",
       "<g id=\"node1\" class=\"node\">\n",
       "<title>2311754063504</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"0,-0.5 0,-69.5 86,-69.5 86,-0.5 0,-0.5\"/>\n",
       "<text text-anchor=\"middle\" x=\"43\" y=\"-54.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">grad = 9.0</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"0,-46.5 86,-46.5 \"/>\n",
       "<text text-anchor=\"middle\" x=\"43\" y=\"-31.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = 4.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"0,-23.5 86,-23.5 \"/>\n",
       "<text text-anchor=\"middle\" x=\"43\" y=\"-8.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">c</text>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.graphs.Digraph at 0x21a3f435110>"
      ]
     },
     "execution_count": 64,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "draw_graph(c)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "id": "10addfab-1bf1-45f9-af2e-7e9609b4745f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 2.50.0 (0)\n",
       " -->\n",
       "<!-- Pages: 1 -->\n",
       "<svg width=\"302pt\" height=\"290pt\"\n",
       " viewBox=\"0.00 0.00 302.00 290.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 286)\">\n",
       "<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-286 298,-286 298,4 -4,4\"/>\n",
       "<!-- 2311752889552 -->\n",
       "<g id=\"node1\" class=\"node\">\n",
       "<title>2311752889552</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"104,-0.5 104,-69.5 190,-69.5 190,-0.5 104,-0.5\"/>\n",
       "<text text-anchor=\"middle\" x=\"147\" y=\"-54.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">grad = 0.0</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"104,-46.5 190,-46.5 \"/>\n",
       "<text text-anchor=\"middle\" x=\"147\" y=\"-31.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = 1.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"104,-23.5 190,-23.5 \"/>\n",
       "<text text-anchor=\"middle\" x=\"147\" y=\"-8.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">a</text>\n",
       "</g>\n",
       "<!-- 2311752601424 -->\n",
       "<g id=\"node4\" class=\"node\">\n",
       "<title>2311752601424</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"52,-106.5 52,-175.5 138,-175.5 138,-106.5 52,-106.5\"/>\n",
       "<text text-anchor=\"middle\" x=\"95\" y=\"-160.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">grad = 0.0</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"52,-152.5 138,-152.5 \"/>\n",
       "<text text-anchor=\"middle\" x=\"95\" y=\"-137.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = 4.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"52,-129.5 138,-129.5 \"/>\n",
       "<text text-anchor=\"middle\" x=\"95\" y=\"-114.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">*</text>\n",
       "</g>\n",
       "<!-- 2311752889552&#45;&gt;2311752601424 -->\n",
       "<g id=\"edge1\" class=\"edge\">\n",
       "<title>2311752889552&#45;&gt;2311752601424</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M130.12,-69.77C125.71,-78.57 120.91,-88.18 116.31,-97.38\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"113.14,-95.89 111.8,-106.4 119.4,-99.02 113.14,-95.89\"/>\n",
       "</g>\n",
       "<!-- 2311752883664 -->\n",
       "<g id=\"node6\" class=\"node\">\n",
       "<title>2311752883664</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"156,-106.5 156,-175.5 242,-175.5 242,-106.5 156,-106.5\"/>\n",
       "<text text-anchor=\"middle\" x=\"199\" y=\"-160.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">grad = 0.0</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"156,-152.5 242,-152.5 \"/>\n",
       "<text text-anchor=\"middle\" x=\"199\" y=\"-137.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = 3.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"156,-129.5 242,-129.5 \"/>\n",
       "<text text-anchor=\"middle\" x=\"199\" y=\"-114.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">+</text>\n",
       "</g>\n",
       "<!-- 2311752889552&#45;&gt;2311752883664 -->\n",
       "<g id=\"edge6\" class=\"edge\">\n",
       "<title>2311752889552&#45;&gt;2311752883664</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M163.88,-69.77C168.29,-78.57 173.09,-88.18 177.69,-97.38\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"174.6,-99.02 182.2,-106.4 180.86,-95.89 174.6,-99.02\"/>\n",
       "</g>\n",
       "<!-- 2311752877776 -->\n",
       "<g id=\"node2\" class=\"node\">\n",
       "<title>2311752877776</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"101,-212.5 101,-281.5 193,-281.5 193,-212.5 101,-212.5\"/>\n",
       "<text text-anchor=\"middle\" x=\"147\" y=\"-266.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">grad = 0.0</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"101,-258.5 193,-258.5 \"/>\n",
       "<text text-anchor=\"middle\" x=\"147\" y=\"-243.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = 12.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"101,-235.5 193,-235.5 \"/>\n",
       "<text text-anchor=\"middle\" x=\"147\" y=\"-220.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">*</text>\n",
       "</g>\n",
       "<!-- 2311752890832 -->\n",
       "<g id=\"node3\" class=\"node\">\n",
       "<title>2311752890832</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"208,-0.5 208,-69.5 294,-69.5 294,-0.5 208,-0.5\"/>\n",
       "<text text-anchor=\"middle\" x=\"251\" y=\"-54.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">grad = 0.0</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"208,-46.5 294,-46.5 \"/>\n",
       "<text text-anchor=\"middle\" x=\"251\" y=\"-31.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = 2.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"208,-23.5 294,-23.5 \"/>\n",
       "<text text-anchor=\"middle\" x=\"251\" y=\"-8.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">b</text>\n",
       "</g>\n",
       "<!-- 2311752890832&#45;&gt;2311752883664 -->\n",
       "<g id=\"edge4\" class=\"edge\">\n",
       "<title>2311752890832&#45;&gt;2311752883664</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M234.12,-69.77C229.71,-78.57 224.91,-88.18 220.31,-97.38\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"217.14,-95.89 215.8,-106.4 223.4,-99.02 217.14,-95.89\"/>\n",
       "</g>\n",
       "<!-- 2311752601424&#45;&gt;2311752877776 -->\n",
       "<g id=\"edge3\" class=\"edge\">\n",
       "<title>2311752601424&#45;&gt;2311752877776</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M111.88,-175.77C116.29,-184.57 121.09,-194.18 125.69,-203.38\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"122.6,-205.02 130.2,-212.4 128.86,-201.89 122.6,-205.02\"/>\n",
       "</g>\n",
       "<!-- 2311752886672 -->\n",
       "<g id=\"node5\" class=\"node\">\n",
       "<title>2311752886672</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"0,-0.5 0,-69.5 86,-69.5 86,-0.5 0,-0.5\"/>\n",
       "<text text-anchor=\"middle\" x=\"43\" y=\"-54.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">grad = 0.0</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"0,-46.5 86,-46.5 \"/>\n",
       "<text text-anchor=\"middle\" x=\"43\" y=\"-31.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = 4.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"0,-23.5 86,-23.5 \"/>\n",
       "<text text-anchor=\"middle\" x=\"43\" y=\"-8.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">c</text>\n",
       "</g>\n",
       "<!-- 2311752886672&#45;&gt;2311752601424 -->\n",
       "<g id=\"edge5\" class=\"edge\">\n",
       "<title>2311752886672&#45;&gt;2311752601424</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M59.88,-69.77C64.29,-78.57 69.09,-88.18 73.69,-97.38\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"70.6,-99.02 78.2,-106.4 76.86,-95.89 70.6,-99.02\"/>\n",
       "</g>\n",
       "<!-- 2311752883664&#45;&gt;2311752877776 -->\n",
       "<g id=\"edge2\" class=\"edge\">\n",
       "<title>2311752883664&#45;&gt;2311752877776</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M182.12,-175.77C177.71,-184.57 172.91,-194.18 168.31,-203.38\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"165.14,-201.89 163.8,-212.4 171.4,-205.02 165.14,-201.89\"/>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.graphs.Digraph at 0x21a3f35b410>"
      ]
     },
     "execution_count": 77,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a=ScalarTmp(1.0,label='a')\n",
    "b=ScalarTmp(2.0,label='b')\n",
    "c=ScalarTmp(4.0,label='c')\n",
    "d = a + b\n",
    "e = c * a\n",
    "f = d*e\n",
    "draw_graph(f)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7c2f304d-fc9e-4489-98b7-89442a579c52",
   "metadata": {},
   "source": [
    "## 链式法则与反向传播\n",
    "\n",
    "$$\n",
    "\\frac{\\partial{y}}{\\partial{a}} = \\sum_{i=1}^{i=n} \\frac{\\partial{y}}{\\partial{o_i}} \\times \\frac{\\partial{o_i}}{\\partial{a}}\n",
    "$$\n",
    "因此需要不重复的遍历所有的$o_i$\n",
    "因此一个节点需要添加两个属性,一个是局部的偏导(作为一个字典),另外是一个全局的偏导(浮点数)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "756db043-8d51-4d5f-b915-8d1c1db450aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ScalarTmp:\n",
    "    def __init__(self,values,prevs = [],op = None,label = None):\n",
    "        self.values = values\n",
    "        self.prevs = prevs\n",
    "        self.op = op\n",
    "        self.label = label\n",
    "        self.grad = 0.0\n",
    "        self.grad_wrt = {}\n",
    "\n",
    "    def __add__(self,other):\n",
    "        # 定义加法运算\n",
    "        values = self.values + other.values\n",
    "        output = ScalarTmp(values,prevs=[self,other],op=\"+\")\n",
    "        output.grad_wrt[self] = 1\n",
    "        output.grad_wrt[other] = 1\n",
    "        return output\n",
    "\n",
    "    def __mul__(self,other):\n",
    "        # 定义乘法法运算\n",
    "        values = self.values * other.values\n",
    "        output = ScalarTmp(values,prevs=[self,other],op=\"*\")\n",
    "        output.grad_wrt[self] = other.values\n",
    "        output.grad_wrt[other] = self.values\n",
    "        return output\n",
    "    def __repr__(self):\n",
    "        # 打印类的信息\n",
    "        return f'{self.values} | {self.op} | {self.label}'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "76eea63c-a26c-4e1b-b31c-0bb8e7a7299d",
   "metadata": {},
   "source": [
    "## 实现拓扑排序\n",
    "本质是深度优先搜索算法,这需要结合一个队列"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "fbb7d867-def2-4807-a158-97387bf1c8ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _top_order(root):\n",
    "    ordered,visited = [] ,set()\n",
    "    def _add_prevs(node):\n",
    "        visited.add(node)\n",
    "        for prev in node.prevs:\n",
    "            _add_prevs(prev)\n",
    "        ordered.append(node)\n",
    "    _add_prevs(root)\n",
    "    return ordered"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dd1c0d6d-d590-46e5-9db6-c8cdff10fa7a",
   "metadata": {},
   "source": [
    "## 测试拓扑排序"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "id": "7781f1be-ba9c-440b-bb33-dfc440b0efe3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[12.0 | * | None,\n",
       " 4.0 | * | None,\n",
       " 1.0 | None | a,\n",
       " 4.0 | None | c,\n",
       " 3.0 | + | None,\n",
       " 2.0 | None | b,\n",
       " 1.0 | None | a]"
      ]
     },
     "execution_count": 57,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "_top_order(f)[::-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "id": "7e8563ad-ea84-4069-aa42-7334f80b3480",
   "metadata": {},
   "outputs": [],
   "source": [
    "def backward(root):\n",
    "    # 定义顶点的梯度等于1\n",
    "    root.grad = 1.0\n",
    "    # 遍历计算图\n",
    "    ordered = _top_order(root)[::-1]\n",
    "    for node in ordered:\n",
    "        for v in node.prevs:\n",
    "            v.grad += node.grad * node.grad_wrt[v]\n",
    "    return root\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "id": "21163116-e724-4f3a-a651-80e2bb3ee1f7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 2.50.0 (0)\n",
       " -->\n",
       "<!-- Pages: 1 -->\n",
       "<svg width=\"302pt\" height=\"290pt\"\n",
       " viewBox=\"0.00 0.00 302.00 290.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 286)\">\n",
       "<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-286 298,-286 298,4 -4,4\"/>\n",
       "<!-- 2311754325008 -->\n",
       "<g id=\"node1\" class=\"node\">\n",
       "<title>2311754325008</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"0,-0.5 0,-69.5 86,-69.5 86,-0.5 0,-0.5\"/>\n",
       "<text text-anchor=\"middle\" x=\"43\" y=\"-54.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">grad = 3.0</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"0,-46.5 86,-46.5 \"/>\n",
       "<text text-anchor=\"middle\" x=\"43\" y=\"-31.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = 4.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"0,-23.5 86,-23.5 \"/>\n",
       "<text text-anchor=\"middle\" x=\"43\" y=\"-8.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">c</text>\n",
       "</g>\n",
       "<!-- 2311753259088 -->\n",
       "<g id=\"node2\" class=\"node\">\n",
       "<title>2311753259088</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"52,-106.5 52,-175.5 138,-175.5 138,-106.5 52,-106.5\"/>\n",
       "<text text-anchor=\"middle\" x=\"95\" y=\"-160.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">grad = 3.0</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"52,-152.5 138,-152.5 \"/>\n",
       "<text text-anchor=\"middle\" x=\"95\" y=\"-137.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = 4.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"52,-129.5 138,-129.5 \"/>\n",
       "<text text-anchor=\"middle\" x=\"95\" y=\"-114.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">*</text>\n",
       "</g>\n",
       "<!-- 2311754325008&#45;&gt;2311753259088 -->\n",
       "<g id=\"edge3\" class=\"edge\">\n",
       "<title>2311754325008&#45;&gt;2311753259088</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M59.88,-69.77C64.29,-78.57 69.09,-88.18 73.69,-97.38\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"70.6,-99.02 78.2,-106.4 76.86,-95.89 70.6,-99.02\"/>\n",
       "</g>\n",
       "<!-- 2311753260816 -->\n",
       "<g id=\"node5\" class=\"node\">\n",
       "<title>2311753260816</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"101,-212.5 101,-281.5 193,-281.5 193,-212.5 101,-212.5\"/>\n",
       "<text text-anchor=\"middle\" x=\"147\" y=\"-266.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">grad = 1.0</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"101,-258.5 193,-258.5 \"/>\n",
       "<text text-anchor=\"middle\" x=\"147\" y=\"-243.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = 12.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"101,-235.5 193,-235.5 \"/>\n",
       "<text text-anchor=\"middle\" x=\"147\" y=\"-220.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">*</text>\n",
       "</g>\n",
       "<!-- 2311753259088&#45;&gt;2311753260816 -->\n",
       "<g id=\"edge2\" class=\"edge\">\n",
       "<title>2311753259088&#45;&gt;2311753260816</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M111.88,-175.77C116.29,-184.57 121.09,-194.18 125.69,-203.38\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"122.6,-205.02 130.2,-212.4 128.86,-201.89 122.6,-205.02\"/>\n",
       "</g>\n",
       "<!-- 2311752781008 -->\n",
       "<g id=\"node3\" class=\"node\">\n",
       "<title>2311752781008</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"104,-0.5 104,-69.5 190,-69.5 190,-0.5 104,-0.5\"/>\n",
       "<text text-anchor=\"middle\" x=\"147\" y=\"-54.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">grad = 16.0</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"104,-46.5 190,-46.5 \"/>\n",
       "<text text-anchor=\"middle\" x=\"147\" y=\"-31.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = 1.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"104,-23.5 190,-23.5 \"/>\n",
       "<text text-anchor=\"middle\" x=\"147\" y=\"-8.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">a</text>\n",
       "</g>\n",
       "<!-- 2311752781008&#45;&gt;2311753259088 -->\n",
       "<g id=\"edge1\" class=\"edge\">\n",
       "<title>2311752781008&#45;&gt;2311753259088</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M130.12,-69.77C125.71,-78.57 120.91,-88.18 116.31,-97.38\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"113.14,-95.89 111.8,-106.4 119.4,-99.02 113.14,-95.89\"/>\n",
       "</g>\n",
       "<!-- 2311753253712 -->\n",
       "<g id=\"node6\" class=\"node\">\n",
       "<title>2311753253712</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"156,-106.5 156,-175.5 242,-175.5 242,-106.5 156,-106.5\"/>\n",
       "<text text-anchor=\"middle\" x=\"199\" y=\"-160.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">grad = 4.0</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"156,-152.5 242,-152.5 \"/>\n",
       "<text text-anchor=\"middle\" x=\"199\" y=\"-137.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = 3.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"156,-129.5 242,-129.5 \"/>\n",
       "<text text-anchor=\"middle\" x=\"199\" y=\"-114.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">+</text>\n",
       "</g>\n",
       "<!-- 2311752781008&#45;&gt;2311753253712 -->\n",
       "<g id=\"edge5\" class=\"edge\">\n",
       "<title>2311752781008&#45;&gt;2311753253712</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M163.88,-69.77C168.29,-78.57 173.09,-88.18 177.69,-97.38\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"174.6,-99.02 182.2,-106.4 180.86,-95.89 174.6,-99.02\"/>\n",
       "</g>\n",
       "<!-- 2311754328784 -->\n",
       "<g id=\"node4\" class=\"node\">\n",
       "<title>2311754328784</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"208,-0.5 208,-69.5 294,-69.5 294,-0.5 208,-0.5\"/>\n",
       "<text text-anchor=\"middle\" x=\"251\" y=\"-54.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">grad = 4.0</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"208,-46.5 294,-46.5 \"/>\n",
       "<text text-anchor=\"middle\" x=\"251\" y=\"-31.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = 2.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"208,-23.5 294,-23.5 \"/>\n",
       "<text text-anchor=\"middle\" x=\"251\" y=\"-8.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">b</text>\n",
       "</g>\n",
       "<!-- 2311754328784&#45;&gt;2311753253712 -->\n",
       "<g id=\"edge4\" class=\"edge\">\n",
       "<title>2311754328784&#45;&gt;2311753253712</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M234.12,-69.77C229.71,-78.57 224.91,-88.18 220.31,-97.38\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"217.14,-95.89 215.8,-106.4 223.4,-99.02 217.14,-95.89\"/>\n",
       "</g>\n",
       "<!-- 2311753253712&#45;&gt;2311753260816 -->\n",
       "<g id=\"edge6\" class=\"edge\">\n",
       "<title>2311753253712&#45;&gt;2311753260816</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M182.12,-175.77C177.71,-184.57 172.91,-194.18 168.31,-203.38\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"165.14,-201.89 163.8,-212.4 171.4,-205.02 165.14,-201.89\"/>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.graphs.Digraph at 0x21a3e32b0d0>"
      ]
     },
     "execution_count": 88,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a=ScalarTmp(1.0,label='a')\n",
    "b=ScalarTmp(2.0,label='b')\n",
    "c=ScalarTmp(4.0,label='c')\n",
    "d = a + b\n",
    "e = c * a\n",
    "f = d*e\n",
    "backward(f)\n",
    "draw_graph(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "id": "620ac0bf-6600-4f1f-a2e2-8863b6e8b61e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 2.50.0 (0)\n",
       " -->\n",
       "<!-- Pages: 1 -->\n",
       "<svg width=\"304pt\" height=\"290pt\"\n",
       " viewBox=\"0.00 0.00 304.00 290.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 286)\">\n",
       "<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-286 300,-286 300,4 -4,4\"/>\n",
       "<!-- 2311754325008 -->\n",
       "<g id=\"node1\" class=\"node\">\n",
       "<title>2311754325008</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"0,-0.5 0,-69.5 86,-69.5 86,-0.5 0,-0.5\"/>\n",
       "<text text-anchor=\"middle\" x=\"43\" y=\"-54.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">grad = 45.0</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"0,-46.5 86,-46.5 \"/>\n",
       "<text text-anchor=\"middle\" x=\"43\" y=\"-31.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = 4.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"0,-23.5 86,-23.5 \"/>\n",
       "<text text-anchor=\"middle\" x=\"43\" y=\"-8.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">c</text>\n",
       "</g>\n",
       "<!-- 2311753259088 -->\n",
       "<g id=\"node2\" class=\"node\">\n",
       "<title>2311753259088</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"53,-106.5 53,-175.5 139,-175.5 139,-106.5 53,-106.5\"/>\n",
       "<text text-anchor=\"middle\" x=\"96\" y=\"-160.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">grad = 15.0</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"53,-152.5 139,-152.5 \"/>\n",
       "<text text-anchor=\"middle\" x=\"96\" y=\"-137.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = 4.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"53,-129.5 139,-129.5 \"/>\n",
       "<text text-anchor=\"middle\" x=\"96\" y=\"-114.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">*</text>\n",
       "</g>\n",
       "<!-- 2311754325008&#45;&gt;2311753259088 -->\n",
       "<g id=\"edge3\" class=\"edge\">\n",
       "<title>2311754325008&#45;&gt;2311753259088</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M60.21,-69.77C64.7,-78.57 69.59,-88.18 74.28,-97.38\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"71.22,-99.08 78.88,-106.4 77.46,-95.9 71.22,-99.08\"/>\n",
       "</g>\n",
       "<!-- 2311753260816 -->\n",
       "<g id=\"node5\" class=\"node\">\n",
       "<title>2311753260816</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"102,-212.5 102,-281.5 194,-281.5 194,-212.5 102,-212.5\"/>\n",
       "<text text-anchor=\"middle\" x=\"148\" y=\"-266.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">grad = 1.0</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"102,-258.5 194,-258.5 \"/>\n",
       "<text text-anchor=\"middle\" x=\"148\" y=\"-243.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = 12.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"102,-235.5 194,-235.5 \"/>\n",
       "<text text-anchor=\"middle\" x=\"148\" y=\"-220.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">*</text>\n",
       "</g>\n",
       "<!-- 2311753259088&#45;&gt;2311753260816 -->\n",
       "<g id=\"edge2\" class=\"edge\">\n",
       "<title>2311753259088&#45;&gt;2311753260816</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M112.88,-175.77C117.29,-184.57 122.09,-194.18 126.69,-203.38\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"123.6,-205.02 131.2,-212.4 129.86,-201.89 123.6,-205.02\"/>\n",
       "</g>\n",
       "<!-- 2311752781008 -->\n",
       "<g id=\"node3\" class=\"node\">\n",
       "<title>2311752781008</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"104.5,-0.5 104.5,-69.5 191.5,-69.5 191.5,-0.5 104.5,-0.5\"/>\n",
       "<text text-anchor=\"middle\" x=\"148\" y=\"-54.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">grad = 240.0</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"104.5,-46.5 191.5,-46.5 \"/>\n",
       "<text text-anchor=\"middle\" x=\"148\" y=\"-31.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = 1.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"104.5,-23.5 191.5,-23.5 \"/>\n",
       "<text text-anchor=\"middle\" x=\"148\" y=\"-8.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">a</text>\n",
       "</g>\n",
       "<!-- 2311752781008&#45;&gt;2311753259088 -->\n",
       "<g id=\"edge1\" class=\"edge\">\n",
       "<title>2311752781008&#45;&gt;2311753259088</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M131.12,-69.77C126.71,-78.57 121.91,-88.18 117.31,-97.38\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"114.14,-95.89 112.8,-106.4 120.4,-99.02 114.14,-95.89\"/>\n",
       "</g>\n",
       "<!-- 2311753253712 -->\n",
       "<g id=\"node6\" class=\"node\">\n",
       "<title>2311753253712</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"157,-106.5 157,-175.5 243,-175.5 243,-106.5 157,-106.5\"/>\n",
       "<text text-anchor=\"middle\" x=\"200\" y=\"-160.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">grad = 20.0</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"157,-152.5 243,-152.5 \"/>\n",
       "<text text-anchor=\"middle\" x=\"200\" y=\"-137.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = 3.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"157,-129.5 243,-129.5 \"/>\n",
       "<text text-anchor=\"middle\" x=\"200\" y=\"-114.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">+</text>\n",
       "</g>\n",
       "<!-- 2311752781008&#45;&gt;2311753253712 -->\n",
       "<g id=\"edge5\" class=\"edge\">\n",
       "<title>2311752781008&#45;&gt;2311753253712</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M164.88,-69.77C169.29,-78.57 174.09,-88.18 178.69,-97.38\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"175.6,-99.02 183.2,-106.4 181.86,-95.89 175.6,-99.02\"/>\n",
       "</g>\n",
       "<!-- 2311754328784 -->\n",
       "<g id=\"node4\" class=\"node\">\n",
       "<title>2311754328784</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"210,-0.5 210,-69.5 296,-69.5 296,-0.5 210,-0.5\"/>\n",
       "<text text-anchor=\"middle\" x=\"253\" y=\"-54.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">grad = 60.0</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"210,-46.5 296,-46.5 \"/>\n",
       "<text text-anchor=\"middle\" x=\"253\" y=\"-31.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = 2.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"210,-23.5 296,-23.5 \"/>\n",
       "<text text-anchor=\"middle\" x=\"253\" y=\"-8.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">b</text>\n",
       "</g>\n",
       "<!-- 2311754328784&#45;&gt;2311753253712 -->\n",
       "<g id=\"edge4\" class=\"edge\">\n",
       "<title>2311754328784&#45;&gt;2311753253712</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M235.79,-69.77C231.3,-78.57 226.41,-88.18 221.72,-97.38\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"218.54,-95.9 217.12,-106.4 224.78,-99.08 218.54,-95.9\"/>\n",
       "</g>\n",
       "<!-- 2311753253712&#45;&gt;2311753260816 -->\n",
       "<g id=\"edge6\" class=\"edge\">\n",
       "<title>2311753253712&#45;&gt;2311753260816</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M183.12,-175.77C178.71,-184.57 173.91,-194.18 169.31,-203.38\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"166.14,-201.89 164.8,-212.4 172.4,-205.02 166.14,-201.89\"/>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.graphs.Digraph at 0x21a3df15ed0>"
      ]
     },
     "execution_count": 93,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "backward(f)\n",
    "draw_graph(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0b0185d-b748-41ac-a105-49c22e27215f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
